{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial 2: Arithmetic Secret Sharing\n",
    "Arithmetic secret sharing is mainly used in secure two-party computation, where each participant holds only a share of the data rather than the data itself, so the underlying values are not leaked during the computation. At present, our models and functions are designed for semi-honest parties.\n",
    "To use arithmetic secret sharing for secure two-party computation, we import the following packages:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-21T11:20:54.514628Z",
     "start_time": "2024-03-21T11:20:52.309934Z"
    }
   },
   "outputs": [],
   "source": [
    "# import the libraries\n",
    "from NssMPC.secure_model.mpc_party.semi_honest import SemiHonestCS\n",
    "from NssMPC import ArithmeticSecretSharing\n",
    "from NssMPC.common.ring.ring_tensor import RingTensor\n",
    "from NssMPC.crypto.aux_parameter.beaver_triples.arithmetic_triples import MatmulTriples\n",
    "\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```SemiHonestCS``` represents a semi-honest party (server or client) in the two-party setting. ```ArithmeticSecretSharing``` is the main class that we use, and ```RingTensor``` is the main data structure that we use. ```MatmulTriples``` supplies the Beaver triples needed for matrix multiplication on arithmetic shares; in this tutorial we generate them locally to simulate a trusted third party that provides auxiliary operation data."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Party\n",
    "First, we need to define the parties involved in the computation. For secure two-party computation, we need two parties: the server and the client.\n",
    "When setting up the parties, we need to specify an address and a port for each party. Each party has a TCP server and a TCP client, and each of these needs an address and a port. We also need to set the Beaver triple provider and the wrap provider for the computations. If you plan to do comparison operations, do not forget to set the comparison key provider.\n",
    "In this demonstration we use multi-threading to simulate the two parties. In real applications, the server and the client run as two separate programs. You can refer to ```./debug/crypto/primitives/arithmetic_secret_sharing/test_ass_server.py``` and ```./debug/crypto/primitives/arithmetic_secret_sharing/test_ass_client.py```."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-21T11:20:55.113313Z",
     "start_time": "2024-03-21T11:20:54.515627Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TCPServer waiting for connection ......\n",
      "TCPServer waiting for connection ......\n",
      "successfully connect to server 127.0.0.1:8000\n",
      "TCPServer successfully connected by :('127.0.0.1', 9100)\n",
      "successfully connect to server 127.0.0.1:9000\n",
      "TCPServer successfully connected by :('127.0.0.1', 8200)\n"
     ]
    }
   ],
   "source": [
    "import threading\n",
    "\n",
    "# set Server\n",
    "server = SemiHonestCS(type='server')\n",
    "\n",
    "server.set_multiplication_provider()\n",
    "server.set_comparison_provider()\n",
    "server.set_nonlinear_operation_provider()\n",
    "\n",
    "def set_server():\n",
    "    # CS connect\n",
    "    server.online()\n",
    "\n",
    "# set Client\n",
    "client = SemiHonestCS(type='client')\n",
    "\n",
    "client.set_multiplication_provider()\n",
    "client.set_comparison_provider()\n",
    "client.set_nonlinear_operation_provider()\n",
    "\n",
    "def set_client():\n",
    "    # CS connect\n",
    "    client.online()\n",
    "\n",
    "\n",
    "server_thread = threading.Thread(target=set_server)\n",
    "client_thread = threading.Thread(target=set_client)\n",
    "\n",
    "server_thread.start()\n",
    "client_thread.start()\n",
    "client_thread.join()\n",
    "server_thread.join()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you see two instances of \"successfully connected\", it indicates that the communication between the two parties has been established successfully."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Secret Sharing\n",
    "If both parties have data that they want to compute on without revealing it to each other, each party can use the ```share``` method of ```ArithmeticSecretSharing``` (ASS) to split its data into shares. Each party then sends one share of its data to the other party over TCP and receives one share of the other party's data in return.\n",
    "In this case, let's assume that the server has data denoted as x, and the client has data denoted as y."
   ]
  },
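  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea behind ```share``` can be sketched in plain Python. The cell below is a minimal illustration of two-party additive sharing, assuming a ring of integers modulo 2**64 and the fixed-point scale of 65536 that appears in the ```scale``` field of the printed shares. It does not call NssMPC and is not its implementation; the helper names ```encode```, ```share_value```, and ```reconstruct``` are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import secrets\n",
    "\n",
    "MODULUS = 1 << 64  # assumed ring size: integers modulo 2**64\n",
    "SCALE = 65536      # fixed-point scale, as shown in the printed shares\n",
    "\n",
    "def encode(value):\n",
    "    # Map a real number to a fixed-point ring element.\n",
    "    return round(value * SCALE) % MODULUS\n",
    "\n",
    "def share_value(secret):\n",
    "    # Split a ring element into two additive shares; each share alone is uniformly random.\n",
    "    share0 = secrets.randbelow(MODULUS)\n",
    "    share1 = (secret - share0) % MODULUS\n",
    "    return share0, share1\n",
    "\n",
    "def reconstruct(share0, share1):\n",
    "    # Add the shares, interpret the sum as a signed value, and undo the scaling.\n",
    "    total = (share0 + share1) % MODULUS\n",
    "    if total >= MODULUS // 2:\n",
    "        total -= MODULUS\n",
    "    return total / SCALE\n",
    "\n",
    "s0, s1 = share_value(encode(-1.0))\n",
    "print(reconstruct(s0, s1))  # -1.0"
   ]
  },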
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-21T11:20:55.284279Z",
     "start_time": "2024-03-21T11:20:55.115373Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "shared x in client: shared x in server:   ArithmeticSecretSharing[\n",
      "RingTensor\n",
      " value:tensor([[-5023461124792412191, -1673044175232054344],\n",
      "        [-1521313688647088049, -4031918201498529903]]) \n",
      " dtype:float \n",
      " scale:65536\n",
      " party:1\n",
      "]\n",
      "shared y in client:  ArithmeticSecretSharing[\n",
      "RingTensor\n",
      " value:tensor([[5023461124792477727, 1673044175232185416],\n",
      "        [1521313688647284657, 4031918201498792047]]) \n",
      " dtype:float \n",
      " scale:65536\n",
      " party:0\n",
      "]\n",
      "shared y in server:  ArithmeticSecretSharing[\n",
      "RingTensor\n",
      " value:tensor([[-6043862842016988588, -1290222660339364218],\n",
      "        [ 2137132316675903433,  4903197941973986200]]) \n",
      " dtype:float \n",
      " scale:65536\n",
      " party:1\n",
      "]\n",
      "ArithmeticSecretSharing[\n",
      "RingTensor\n",
      " value:tensor([[ 6043862842016923052,  1290222660339495290],\n",
      "        [-2137132316675641289, -4903197941973789592]]) \n",
      " dtype:float \n",
      " scale:65536\n",
      " party:0\n",
      "]\n"
     ]
    }
   ],
   "source": [
    "from NssMPC.config.configs import DEVICE\n",
    "\n",
    "# data belong to server\n",
    "x = RingTensor.convert_to_ring(torch.tensor([[1.0, 2.0], [3.0, 4.0]], device=DEVICE))\n",
    "# data belong to client\n",
    "y = RingTensor.convert_to_ring(torch.tensor([[-1.0, 2.0], [4.0, 3.0]], device=DEVICE))\n",
    "\n",
    "# split x into 2 parts\n",
    "X = ArithmeticSecretSharing.share(x, 2)\n",
    "\n",
    "# split y into 2 parts\n",
    "Y = ArithmeticSecretSharing.share(y, 2)\n",
    "\n",
    "temp_shared_x0 = ArithmeticSecretSharing(X[0].ring_tensor, server)\n",
    "temp_shared_x1 = ArithmeticSecretSharing(X[1].ring_tensor, client)\n",
    "temp_shared_y0 = ArithmeticSecretSharing(Y[0].ring_tensor, server)\n",
    "temp_shared_y1 = ArithmeticSecretSharing(Y[1].ring_tensor, client)\n",
    "\n",
    "def server_action():\n",
    "    # server shares x1 to client\n",
    "    server.send(X[1])\n",
    "    shared_x_0 = ArithmeticSecretSharing(X[0].ring_tensor,server)\n",
    "    # server receives y0 from client\n",
    "    y0 = server.receive()\n",
    "    shared_y_0 = ArithmeticSecretSharing(y0.ring_tensor,server)\n",
    "    print(\"shared x in server: \", shared_x_0)\n",
    "    print(\"shared y in server: \", shared_y_0)\n",
    "\n",
    "def client_action():\n",
    "    # client receives x1 from server\n",
    "    x1 = client.receive()\n",
    "    # client shares y0 to server\n",
    "    client.send(Y[0])\n",
    "    shared_x_1 = ArithmeticSecretSharing(x1.ring_tensor,client)\n",
    "    shared_y_1 = ArithmeticSecretSharing(Y[1].ring_tensor,client)\n",
    "    print(\"shared x in client: \", shared_x_1)\n",
    "    print(\"shared y in client: \", shared_y_1)\n",
    "\n",
    "server_thread = threading.Thread(target=server_action)\n",
    "client_thread = threading.Thread(target=client_action)\n",
    "\n",
    "server_thread.start()\n",
    "client_thread.start()\n",
    "client_thread.join()\n",
    "server_thread.join()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Secret Restoring\n",
    "If you want to restore the original value from the shares, you can use the ```restore()``` method, which returns a ```RingTensor```; calling ```convert_to_real_field()``` on it then recovers the real-valued result. Both parties must call ```restore()``` so that the shares can be exchanged.\n",
    "In this tutorial, we only print the recovered results on the server side."
   ]
  },
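  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, restoring adds the two shares with 64-bit wrap-around, and converting to the real field divides by the fixed-point scale. The sketch below checks this on the first row of the two shares of ```x``` as printed in this notebook's outputs. It is plain Python for illustration only and does not call NssMPC; ```to_signed64``` is a hypothetical helper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SCALE = 65536  # fixed-point scale, as shown in the printed shares\n",
    "\n",
    "def to_signed64(value):\n",
    "    # Wrap an integer into the signed 64-bit range, as int64 addition would.\n",
    "    value %= 1 << 64\n",
    "    return value - (1 << 64) if value >= 1 << 63 else value\n",
    "\n",
    "# First row of the two shares of x, copied from the outputs in this notebook.\n",
    "x_share_party0 = [5023461124792477727, 1673044175232185416]\n",
    "x_share_party1 = [-5023461124792412191, -1673044175232054344]\n",
    "\n",
    "ring_values = [to_signed64(a + b) for a, b in zip(x_share_party0, x_share_party1)]\n",
    "real_values = [v / SCALE for v in ring_values]\n",
    "print(ring_values)  # [65536, 131072]\n",
    "print(real_values)  # [1.0, 2.0]"
   ]
  },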
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-21T11:20:55.299955Z",
     "start_time": "2024-03-21T11:20:55.286367Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "temp_shared_x0 ArithmeticSecretSharing[\n",
      "RingTensor\n",
      " value:tensor([[5023461124792477727, 1673044175232185416],\n",
      "        [1521313688647284657, 4031918201498792047]]) \n",
      " dtype:float \n",
      " scale:65536\n",
      " party:0\n",
      "]\n",
      "\n",
      " x after restoring: tensor([[1., 2.],\n",
      "        [3., 4.]], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# restore share_x\n",
    "# server\n",
    "\n",
    "print(\"temp_shared_x0\",temp_shared_x0)\n",
    "def restore_server():\n",
    "    restored_x = temp_shared_x0.restore()\n",
    "    real_x = restored_x.convert_to_real_field()\n",
    "    print(\"\\n x after restoring:\", real_x)\n",
    "\n",
    "# client\n",
    "def restore_client():\n",
    "    temp_shared_x1.restore()\n",
    "\n",
    "server_thread = threading.Thread(target=restore_server)\n",
    "client_thread = threading.Thread(target=restore_client)\n",
    "\n",
    "server_thread.start()\n",
    "client_thread.start()\n",
    "client_thread.join()\n",
    "server_thread.join()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Operations\n",
    "Next, we'll show you how to use arithmetic secret sharing to achieve secure two-party computation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Arithmetic Operations"
   ]
  },
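  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In additive secret sharing, addition and subtraction are local: each party combines its own shares, and nothing is exchanged until the result is restored. A minimal plain-Python sketch, assuming additive shares modulo 2**64 (the helper ```split``` and the variable names are illustrative and not part of NssMPC):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import secrets\n",
    "\n",
    "MODULUS = 1 << 64  # assumed ring size\n",
    "\n",
    "def split(secret):\n",
    "    # Return two additive shares of an integer secret.\n",
    "    share0 = secrets.randbelow(MODULUS)\n",
    "    return share0, (secret - share0) % MODULUS\n",
    "\n",
    "x0, x1 = split(3)\n",
    "y0, y1 = split(4)\n",
    "\n",
    "# Each party works only on the shares it holds.\n",
    "sum0, sum1 = (x0 + y0) % MODULUS, (x1 + y1) % MODULUS\n",
    "diff0, diff1 = (x0 - y0) % MODULUS, (x1 - y1) % MODULUS\n",
    "\n",
    "# Restoring adds the two result shares.\n",
    "total = (sum0 + sum1) % MODULUS\n",
    "difference = (diff0 + diff1) % MODULUS\n",
    "print(total)                      # 7\n",
    "print(difference == MODULUS - 1)  # True: -1 wraps around to 2**64 - 1"
   ]
  },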
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-21T11:20:55.315597Z",
     "start_time": "2024-03-21T11:20:55.302043Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Addition tensor([[0., 4.],\n",
      "        [7., 7.]], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# Addition\n",
    "# restore result\n",
    "def addition_server():\n",
    "    res_0 = temp_shared_x0 + temp_shared_y0\n",
    "    result_restored = res_0.restore().convert_to_real_field()\n",
    "    print(\"\\nAddition\", result_restored)\n",
    "\n",
    "def addition_client():\n",
    "    res_1 = temp_shared_x1 + temp_shared_y1\n",
    "    res_1.restore()\n",
    "\n",
    "server_thread = threading.Thread(target=addition_server)\n",
    "client_thread = threading.Thread(target=addition_client)\n",
    "\n",
    "server_thread.start()\n",
    "client_thread.start()\n",
    "client_thread.join()\n",
    "server_thread.join()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-21T11:20:55.331058Z",
     "start_time": "2024-03-21T11:20:55.316607Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Subtraction tensor([[ 2.,  0.],\n",
      "        [-1.,  1.]], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# Subtraction\n",
    "# restore result\n",
    "def subtraction_server():\n",
    "    res_0 = temp_shared_x0 - temp_shared_y0\n",
    "    result_restored = res_0.restore().convert_to_real_field()\n",
    "    print(\"\\nSubtraction\", result_restored)\n",
    "\n",
    "def subtraction_client():\n",
    "    res_1 = temp_shared_x1 - temp_shared_y1\n",
    "    res_1.restore()\n",
    "\n",
    "server_thread = threading.Thread(target=subtraction_server)\n",
    "client_thread = threading.Thread(target=subtraction_client)\n",
    "\n",
    "server_thread.start()\n",
    "client_thread.start()\n",
    "client_thread.join()\n",
    "server_thread.join()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-21T11:20:55.362370Z",
     "start_time": "2024-03-21T11:20:55.332115Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " Multiplication tensor([[-1.,  4.],\n",
      "        [12., 12.]], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# Multiplication\n",
    "# restore result\n",
    "def multiplication_server():\n",
    "    res_0 = temp_shared_x0 * temp_shared_y0\n",
    "    result_restored = res_0.restore().convert_to_real_field()\n",
    "    print(\"\\n Multiplication\", result_restored)\n",
    "\n",
    "def multiplication_client():\n",
    "    res_1 = temp_shared_x1 * temp_shared_y1\n",
    "    res_1.restore()\n",
    "\n",
    "server_thread = threading.Thread(target=multiplication_server)\n",
    "client_thread = threading.Thread(target=multiplication_client)\n",
    "\n",
    "server_thread.start()\n",
    "client_thread.start()\n",
    "client_thread.join()\n",
    "server_thread.join()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: All Beaver triples are generated during the offline phase, so remember to generate the required matrix Beaver triples before performing matrix multiplication."
   ]
  },
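  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To show why triples are needed, the cell below sketches the textbook Beaver-triple multiplication for two parties in plain Python. It assumes additive shares modulo 2**64 and uses plain integers, so the truncation that fixed-point values require after a multiplication is left out. The helpers ```split``` and ```beaver_multiply``` are hypothetical, and the dealer is simulated inside the function; this is not NssMPC's implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import secrets\n",
    "\n",
    "MODULUS = 1 << 64  # assumed ring size\n",
    "\n",
    "def split(secret):\n",
    "    # Return two additive shares of an integer secret.\n",
    "    share0 = secrets.randbelow(MODULUS)\n",
    "    return share0, (secret - share0) % MODULUS\n",
    "\n",
    "def beaver_multiply(x_shares, y_shares):\n",
    "    # Offline phase: a dealer samples a and b, then shares a, b and c = a * b.\n",
    "    a = secrets.randbelow(MODULUS)\n",
    "    b = secrets.randbelow(MODULUS)\n",
    "    a_sh, b_sh, c_sh = split(a), split(b), split(a * b % MODULUS)\n",
    "\n",
    "    # Online phase: the parties open the masked values e = x - a and f = y - b.\n",
    "    e = sum(x_shares[i] - a_sh[i] for i in range(2)) % MODULUS\n",
    "    f = sum(y_shares[i] - b_sh[i] for i in range(2)) % MODULUS\n",
    "\n",
    "    # Each party computes its share of x * y; only party 0 adds the public term e * f.\n",
    "    z_shares = []\n",
    "    for i in range(2):\n",
    "        z = c_sh[i] + e * b_sh[i] + f * a_sh[i]\n",
    "        if i == 0:\n",
    "            z += e * f\n",
    "        z_shares.append(z % MODULUS)\n",
    "    return z_shares\n",
    "\n",
    "z0, z1 = beaver_multiply(split(3), split(4))\n",
    "print((z0 + z1) % MODULUS)  # 12"
   ]
  },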
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-21T11:20:56.046780Z",
     "start_time": "2024-03-21T11:20:55.363382Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x @ y:  RingTensor\n",
      " value:tensor([[ 458752,  524288],\n",
      "        [ 851968, 1179648]]) \n",
      " dtype:float \n",
      " scale:65536\n",
      "real_field(x @ y):  tensor([[ 7.,  8.],\n",
      "        [13., 18.]], dtype=torch.float64)\n",
      "restored x @ y:  restored x @ y:  tensor([[ 7.0000,  8.0000],\n",
      "        [13.0000, 18.0000]], dtype=torch.float64)\n",
      "tensor([[ 7.0000,  8.0000],\n",
      "        [13.0000, 18.0000]], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# Matrix Multiplication\n",
    "from NssMPC.config.configs import DEBUG_LEVEL\n",
    "\n",
    "def server_matrix_multiplication():\n",
    "    # gen beaver triples in advance\n",
    "    if DEBUG_LEVEL != 2:\n",
    "        triples = MatmulTriples.gen(1, x.shape, y.shape)\n",
    "        server.providers[MatmulTriples].param = [triples[0]]\n",
    "        server.send(triples[1])\n",
    "        server.providers[MatmulTriples].load_mat_beaver()\n",
    "\n",
    "    print('x @ y: ', x @ y)\n",
    "    print('real_field(x @ y): ',(x @ y).convert_to_real_field())\n",
    "    share_z = temp_shared_x0 @ temp_shared_y0\n",
    "    res_share_z = share_z.restore().convert_to_real_field()\n",
    "    print('restored x @ y: ', res_share_z)\n",
    "    assert torch.allclose((x @ y).convert_to_real_field(), res_share_z, atol=1e-3, rtol=1e-3)\n",
    "\n",
    "def client_matrix_multiplication():\n",
    "    if DEBUG_LEVEL != 2:\n",
    "        client.providers[MatmulTriples].param = [client.receive()]\n",
    "        client.providers[MatmulTriples].load_mat_beaver()\n",
    "\n",
    "    share_z = temp_shared_x1 @ temp_shared_y1\n",
    "    print('restored x @ y: ', share_z.restore().convert_to_real_field())\n",
    "\n",
    "\n",
    "server_thread = threading.Thread(target=server_matrix_multiplication)\n",
    "client_thread = threading.Thread(target=client_matrix_multiplication)\n",
    "\n",
    "server_thread.start()\n",
    "client_thread.start()\n",
    "client_thread.join()\n",
    "server_thread.join()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Comparison Operations\n",
    "Comparisons are evaluated element-wise on the shared values. In the restored result, ```0``` and ```1``` correspond to ```False``` and ```True```, respectively."
   ]
  },
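  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reference for checking the restored results, the same comparisons can be computed on the plaintext inputs. The cell below does this in plain Python with the values of ```x``` and ```y``` used in this tutorial; ```plain_x```, ```plain_y```, and ```compare``` are illustrative names, and nothing here is computed securely."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import operator\n",
    "\n",
    "# Plaintext copies of the tutorial inputs (server's x and client's y).\n",
    "plain_x = [[1.0, 2.0], [3.0, 4.0]]\n",
    "plain_y = [[-1.0, 2.0], [4.0, 3.0]]\n",
    "\n",
    "def compare(op):\n",
    "    # Apply op element-wise and report True/False as 1.0/0.0.\n",
    "    return [[float(op(a, b)) for a, b in zip(row_x, row_y)]\n",
    "            for row_x, row_y in zip(plain_x, plain_y)]\n",
    "\n",
    "print(compare(operator.lt))  # [[0.0, 0.0], [1.0, 0.0]]\n",
    "print(compare(operator.le))  # [[0.0, 1.0], [1.0, 0.0]]\n",
    "print(compare(operator.gt))  # [[1.0, 0.0], [0.0, 1.0]]\n",
    "print(compare(operator.ge))  # [[1.0, 1.0], [0.0, 1.0]]"
   ]
  },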
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-21T11:20:56.108906Z",
     "start_time": "2024-03-21T11:20:56.048823Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "(x < y) tensor([[0., 0.],\n",
      "        [1., 0.]], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# Server less than\n",
    "def less_than_server():\n",
    "    res_0 = temp_shared_x0 < temp_shared_y0\n",
    "    result_restored = res_0.restore().convert_to_real_field()\n",
    "    print(\"\\n(x < y)\", result_restored)\n",
    "    \n",
    "# Client less than\n",
    "def less_than_client():\n",
    "    res_1 = temp_shared_x1 < temp_shared_y1\n",
    "    res_1.restore()\n",
    "\n",
    "server_thread = threading.Thread(target=less_than_server)\n",
    "client_thread = threading.Thread(target=less_than_client)\n",
    "\n",
    "server_thread.start()\n",
    "client_thread.start()\n",
    "client_thread.join()\n",
    "server_thread.join()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-21T11:20:56.171514Z",
     "start_time": "2024-03-21T11:20:56.109919Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "(x <= y) tensor([[0., 1.],\n",
      "        [1., 0.]], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# Less than or equal\n",
    "def less_equal_server():\n",
    "    res_0 = temp_shared_x0 <= temp_shared_y0\n",
    "    result_restored = res_0.restore().convert_to_real_field()\n",
    "    print(\"\\n(x <= y)\", result_restored)\n",
    "\n",
    "def less_equal_client():\n",
    "    res_1 = temp_shared_x1 <= temp_shared_y1\n",
    "    res_1.restore()\n",
    "\n",
    "server_thread = threading.Thread(target=less_equal_server)\n",
    "client_thread = threading.Thread(target=less_equal_client)\n",
    "\n",
    "server_thread.start()\n",
    "client_thread.start()\n",
    "client_thread.join()\n",
    "server_thread.join()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-21T11:20:56.234787Z",
     "start_time": "2024-03-21T11:20:56.173590Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "(x > y) tensor([[1., 0.],\n",
      "        [0., 1.]], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# Greater than\n",
    "def greater_than_server():\n",
    "    res_0 = temp_shared_x0 > temp_shared_y0\n",
    "    result_restored = res_0.restore().convert_to_real_field()\n",
    "    print(\"\\n(x > y)\", result_restored)\n",
    "\n",
    "def greater_than_client():\n",
    "    res_1 = temp_shared_x1 > temp_shared_y1\n",
    "    res_1.restore()\n",
    "\n",
    "server_thread = threading.Thread(target=greater_than_server)\n",
    "client_thread = threading.Thread(target=greater_than_client)\n",
    "\n",
    "server_thread.start()\n",
    "client_thread.start()\n",
    "client_thread.join()\n",
    "server_thread.join()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-21T11:20:56.297190Z",
     "start_time": "2024-03-21T11:20:56.235798Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "(x >= y) tensor([[1., 1.],\n",
      "        [0., 1.]], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# Greater than or equal\n",
    "def greater_equal_server():\n",
    "    res_0 = temp_shared_x0 >= temp_shared_y0\n",
    "    result_restored = res_0.restore().convert_to_real_field()\n",
    "    print(\"\\n(x >= y)\", result_restored)\n",
    "\n",
    "def greater_equal_client():\n",
    "    res_1 = temp_shared_x1 >= temp_shared_y1\n",
    "    res_1.restore()\n",
    "\n",
    "server_thread = threading.Thread(target=greater_equal_server)\n",
    "client_thread = threading.Thread(target=greater_equal_client)\n",
    "\n",
    "server_thread.start()\n",
    "client_thread.start()\n",
    "client_thread.join()\n",
    "server_thread.join()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
